Bootstrapping model fits¶
The previous section described fitting a single model. Often we also want confidence estimates on the fit, which we can obtain by bootstrapping the data set.
The recommended workflow is to first fit models to all the data to determine the number of epitopes and other fitting parameters. Once those are settled, you can bootstrap to get confidence estimates on the predictions.
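To build intuition for what bootstrapping provides, here is a minimal NumPy sketch of the principle on toy data (not the polyclonal API): resample the data with replacement, recompute the quantity of interest on each resample, and use the spread across resamples as a confidence estimate.

```python
import numpy as np

rng = np.random.default_rng(1)
data = rng.normal(loc=2.0, scale=0.5, size=200)  # toy stand-in for a fitted quantity

# bootstrap: resample with replacement, recompute the statistic each time
boot_means = [
    rng.choice(data, size=data.size, replace=True).mean() for _ in range(100)
]
estimate = np.mean(boot_means)  # central estimate
std_err = np.std(boot_means)  # spread across resamples ~ confidence on the fit
```

The polyclonal bootstrapping below follows the same idea, but each "statistic" is a full model fit to a resampled variant data set.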
Get model fit to the data¶
The first step is simply to fit a Polyclonal model to all the data we are using. We proceed as in the previous notebook for our RBD example, but first shrink the data set to just 7500 variants to introduce more “error” and better illustrate the bootstrapping.
We will call this model, fit to all the data we are using, the “root” model, as it is the starting point (root) for the subsequent bootstrapping. Note that the data (which we will bootstrap) are attached to this pre-fit model:
[1]:
# NBVAL_IGNORE_OUTPUT
import pandas as pd
import polyclonal
# read the data, and just make "barcode" the numerical rank of the variants
noisy_data = (
    pd.read_csv("RBD_variants_escape_noisy.csv", na_filter=None)
    .query('library == "avg3muts"')
    .query("concentration in [0.25, 1, 4]")
    .sort_values(["concentration", "aa_substitutions"])
    .reset_index(drop=True)
    .assign(barcode=lambda x: x.groupby("concentration").cumcount())
)
# just keep some variants to make fitting "noisier"
n_keep = 7500
barcodes_to_keep = (
    noisy_data["barcode"].drop_duplicates().sample(n_keep, random_state=1).tolist()
)
noisy_data = noisy_data.query("barcode in @barcodes_to_keep")
# make and fit the root Polyclonal object with all the data we are using
root_poly = polyclonal.Polyclonal(
    data_to_fit=noisy_data,
    activity_wt_df=pd.DataFrame.from_records(
        [
            ("class 1", 1.0),
            ("class 2", 3.0),
            ("class 3", 2.0),
        ],
        columns=["epitope", "activity"],
    ),
    site_escape_df=pd.DataFrame.from_records(
        [
            ("class 1", 417, 10.0),
            ("class 2", 484, 10.0),
            ("class 3", 444, 10.0),
        ],
        columns=["epitope", "site", "escape"],
    ),
    data_mut_escape_overlap="fill_to_data",
)
opt_res = root_poly.fit(logfreq=100)
# First fitting site-level model.
# Starting optimization of 522 parameters at Sat Mar 19 09:34:43 2022.
step time_sec loss fit_loss reg_escape reg_spread
0 0.022838 4506 4505.7 0.29701 0
100 2.5614 550.09 546.34 3.7432 0
200 4.9081 541.57 537.02 4.5554 0
300 7.272 539 533.84 5.1659 0
400 9.5925 538.27 532.9 5.3674 0
500 11.941 537.67 532.23 5.4371 0
600 14.214 537.13 531.63 5.5078 0
700 16.518 536.64 530.85 5.7896 0
800 18.892 536.31 530.4 5.9137 0
900 21.26 536.06 530.11 5.956 0
1000 23.576 535.51 529.4 6.1068 0
1100 25.904 535.05 528.9 6.1524 0
1200 28.217 534.73 528.47 6.2531 0
1300 30.477 534.43 528.04 6.3902 0
1400 32.749 534.36 527.89 6.4741 0
1500 35.087 534.33 527.77 6.5579 0
1558 36.507 534.32 527.71 6.6146 0
# Successfully finished at Sat Mar 19 09:35:20 2022.
# Starting optimization of 5799 parameters at Sat Mar 19 09:35:20 2022.
step time_sec loss fit_loss reg_escape reg_spread
0 0.027019 643.27 566.6 76.667 1.589e-29
100 2.8415 323.66 237.57 75.125 10.97
200 5.609 310.33 223.52 70.913 15.9
300 8.423 300.99 220.06 64.13 16.8
400 11.172 286.66 215.39 52.955 18.31
500 13.872 276.87 209.89 47.828 19.156
600 16.618 271.09 205.02 45.623 20.445
700 19.431 265.34 199.01 44.579 21.751
800 22.216 260.79 192.97 44.793 23.031
900 24.935 258.05 188.63 45.423 24.004
1000 27.775 255.18 183.5 46.481 25.201
1100 30.532 252.88 179.9 47.376 25.609
1200 33.198 251.35 177.86 47.671 25.813
1300 35.91 249.55 174.99 48.19 26.365
1400 38.565 247.69 172.67 48.51 26.513
1500 41.315 246.33 171.11 48.751 26.472
1600 43.945 245.85 170.46 48.93 26.455
1700 46.531 245.54 169.81 49.079 26.647
1800 49.215 245.09 169.05 49.252 26.788
1900 52.001 244.8 168.4 49.421 26.982
2000 54.759 244.57 168.41 49.371 26.794
2100 57.55 244.4 168.07 49.405 26.926
2200 60.373 244.25 167.91 49.429 26.915
2300 63.072 244.16 167.75 49.47 26.941
2400 65.807 244.06 167.59 49.501 26.979
2500 68.558 243.94 167.38 49.51 27.046
2600 71.229 243.83 167.26 49.516 27.049
2700 74.031 243.75 167.16 49.522 27.07
2800 76.775 243.7 167.13 49.52 27.051
2900 79.589 243.67 167.04 49.52 27.112
3000 82.234 243.61 167.06 49.506 27.044
3100 84.976 243.57 167.08 49.505 26.99
3200 87.754 243.53 167.02 49.545 26.973
3300 90.538 243.5 167.04 49.555 26.906
3400 93.249 243.48 167.01 49.575 26.897
3500 95.987 243.45 166.99 49.587 26.87
3596 98.538 243.42 166.97 49.586 26.866
# Successfully finished at Sat Mar 19 09:36:59 2022.
Create and fit bootstrapped models¶
To create the bootstrapped models, we initialize a PolyclonalCollection, here using just 5 samples for speed (for real analyses, good error estimates may require on the order of 20 to 100 bootstrap samples). It is important that the root model has already been fit to the data! Note also the n_threads option, which specifies how many threads to use for the bootstrapping: by default it is -1 (use all available CPUs), but set it to another number if you want to limit CPU usage:
[2]:
n_bootstrap_samples = 5
bootstrap_poly = polyclonal.PolyclonalCollection(
    root_polyclonal=root_poly,
    n_bootstrap_samples=n_bootstrap_samples,
)
Now fit the bootstrapped models:
[3]:
# NBVAL_IGNORE_OUTPUT
import time
start = time.time()
print(f"Starting fitting bootstrap models at {time.asctime()}")
n_fit, n_failed = bootstrap_poly.fit_models()
print(f"Fitting took {time.time() - start:.3g} seconds, finished at {time.asctime()}")
assert n_failed == 0 and n_fit == n_bootstrap_samples
Starting fitting bootstrap models at Sat Mar 19 09:37:00 2022
Fitting took 65.2 seconds, finished at Sat Mar 19 09:38:06 2022
Look at summarized results¶
We can get the resulting measurements for the epitope activities and mutation effects both per replicate and summarized across replicates (mean, median, standard deviation).
Epitope activities¶
Epitope activities for each replicate:
[4]:
# NBVAL_IGNORE_OUTPUT
bootstrap_poly.activity_wt_df_replicates.round(1)
[4]:
| | epitope | activity | bootstrap_replicate |
|---|---|---|---|
| 0 | class 1 | 2.0 | 1 |
| 1 | class 2 | 2.6 | 1 |
| 2 | class 3 | 2.1 | 1 |
| 3 | class 1 | 1.9 | 2 |
| 4 | class 2 | 2.7 | 2 |
| 5 | class 3 | 1.9 | 2 |
| 6 | class 1 | 2.1 | 3 |
| 7 | class 2 | 2.5 | 3 |
| 8 | class 3 | 2.0 | 3 |
| 9 | class 1 | 1.9 | 4 |
| 10 | class 2 | 2.7 | 4 |
| 11 | class 3 | 2.0 | 4 |
| 12 | class 1 | 2.1 | 5 |
| 13 | class 2 | 2.6 | 5 |
| 14 | class 3 | 1.9 | 5 |
Epitope activities summarized across replicates. The std column gives the standard deviation:
[5]:
# NBVAL_IGNORE_OUTPUT
bootstrap_poly.activity_wt_df.round(1)
[5]:
| | epitope | mean | median | std |
|---|---|---|---|---|
| 0 | class 1 | 2.0 | 2.0 | 0.1 |
| 1 | class 2 | 2.6 | 2.6 | 0.1 |
| 2 | class 3 | 2.0 | 2.0 | 0.1 |
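The relationship between the per-replicate and summarized tables is a simple groupby aggregation. Here is a sketch with hypothetical toy values mimicking activity_wt_df_replicates (not the actual polyclonal internals):

```python
import pandas as pd

# toy per-replicate activities, mimicking activity_wt_df_replicates
replicates = pd.DataFrame(
    {
        "epitope": ["class 1", "class 1", "class 2", "class 2"],
        "activity": [2.0, 1.9, 2.6, 2.7],
        "bootstrap_replicate": [1, 2, 1, 2],
    }
)

# summarize across replicates, analogous to activity_wt_df
summary = (
    replicates.groupby("epitope")["activity"]
    .agg(["mean", "median", "std"])
    .reset_index()
)
```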
We can plot the epitope activities summarized across replicates. The dropdown allows you to choose the summary stat (mean, median), and the black lines indicate the standard deviation. Mouse over for values:
[6]:
# NBVAL_IGNORE_OUTPUT
bootstrap_poly.activity_wt_barplot()
[6]:
Mutation escape values¶
Mutation escape values for each replicate:
[7]:
# NBVAL_IGNORE_OUTPUT
bootstrap_poly.mut_escape_df_replicates.round(1).head()
[7]:
| | epitope | site | wildtype | mutant | mutation | escape | bootstrap_replicate |
|---|---|---|---|---|---|---|---|
| 0 | class 1 | 331 | N | A | N331A | 0.4 | 1 |
| 1 | class 1 | 331 | N | D | N331D | -0.4 | 1 |
| 2 | class 1 | 331 | N | E | N331E | 0.3 | 1 |
| 3 | class 1 | 331 | N | F | N331F | 0.1 | 1 |
| 4 | class 1 | 331 | N | G | N331G | 0.2 | 1 |
Mutation escape values summarized across replicates. Note that the frac_bootstrap_replicates column gives the fraction of bootstrap replicates with a value for each mutation:
[8]:
# NBVAL_IGNORE_OUTPUT
bootstrap_poly.mut_escape_df.round(1).head(n=3)
[8]:
| | epitope | site | wildtype | mutant | mutation | mean | median | std | n_bootstrap_replicates | frac_bootstrap_replicates |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | class 1 | 331 | N | A | N331A | 0.2 | 0.2 | 0.4 | 5 | 1.0 |
| 1 | class 1 | 331 | N | D | N331D | -0.2 | -0.1 | 0.2 | 5 | 1.0 |
| 2 | class 1 | 331 | N | E | N331E | -0.0 | 0.0 | 0.3 | 5 | 1.0 |
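The same filtering that min_frac_bootstrap_replicates performs in the plots can be applied directly to the data frame with pandas. A sketch with hypothetical toy values (not the actual fit results):

```python
import pandas as pd

# toy mutation summary, mimicking columns of mut_escape_df
mut_df = pd.DataFrame(
    {
        "mutation": ["N331A", "N331D", "E484K"],
        "mean": [0.2, -0.2, 3.1],
        "frac_bootstrap_replicates": [1.0, 0.6, 1.0],
    }
)

# keep only mutations seen in >=90% of bootstrap replicates,
# dropping rare mutations with unreliable estimates
filtered = mut_df.query("frac_bootstrap_replicates >= 0.9")
```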
We can plot the mutation escape values across replicates. The dropdown selects the statistic shown in the heatmap (mean or median), and mouseovers give details on points. Here we set min_frac_bootstrap_replicates=0.9 to only report escape values observed in at least 90% of bootstrap replicates (this gets rid of rare mutations):
[9]:
# NBVAL_IGNORE_OUTPUT
bootstrap_poly.mut_escape_heatmap(min_frac_bootstrap_replicates=0.9)
[9]:
Site summaries of mutation escape¶
Site summaries of mutation escape values for replicates:
[10]:
# NBVAL_IGNORE_OUTPUT
bootstrap_poly.mut_escape_site_summary_df_replicates.round(1).head()
[10]:
| | epitope | site | wildtype | mean | total positive | max | min | total negative | bootstrap_replicate |
|---|---|---|---|---|---|---|---|---|---|
| 0 | class 1 | 331 | N | 0.5 | 8.9 | 1.8 | -0.7 | -1.3 | 1 |
| 1 | class 1 | 332 | I | 0.6 | 10.6 | 1.5 | 0.0 | 0.0 | 1 |
| 2 | class 1 | 333 | T | 0.5 | 9.4 | 1.3 | -0.7 | -0.9 | 1 |
| 3 | class 1 | 334 | N | 0.8 | 13.9 | 1.9 | -0.2 | -0.3 | 1 |
| 4 | class 1 | 335 | L | 0.5 | 9.5 | 1.5 | -0.5 | -0.8 | 1 |
Site summaries of mutation escape values summarized (e.g., averaged) across replicates. Note that the metric column now gives a separate row for each site-summary metric type, each summarized by its mean, median, and standard deviation:
[11]:
# NBVAL_IGNORE_OUTPUT
bootstrap_poly.mut_escape_site_summary_df.round(1).head()
[11]:
| | epitope | site | wildtype | metric | mean | median | std | n_bootstrap_replicates | frac_bootstrap_replicates |
|---|---|---|---|---|---|---|---|---|---|
| 0 | class 1 | 331 | N | max | 1.6 | 1.7 | 0.2 | 5 | 1.0 |
| 1 | class 1 | 331 | N | mean | 0.5 | 0.5 | 0.2 | 5 | 1.0 |
| 2 | class 1 | 331 | N | min | -0.5 | -0.5 | 0.2 | 5 | 1.0 |
| 3 | class 1 | 331 | N | total negative | -1.1 | -1.3 | 0.7 | 5 | 1.0 |
| 4 | class 1 | 331 | N | total positive | 8.4 | 8.9 | 2.0 | 5 | 1.0 |
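If you prefer one column per site-summary metric rather than the long format above, a pandas pivot reshapes it. A sketch with hypothetical toy values (not the actual fit results):

```python
import pandas as pd

# toy long-format site summary, mimicking mut_escape_site_summary_df
site_summary = pd.DataFrame(
    {
        "site": [331, 331, 331],
        "metric": ["max", "mean", "min"],
        "mean": [1.6, 0.5, -0.5],
    }
)

# pivot so each site-summary metric becomes its own column
wide = site_summary.pivot(index="site", columns="metric", values="mean")
```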
We can plot site summaries of the mutation escape. Options let you toggle the error bars (standard deviations) on and off, choose which metric is shown (e.g., mean effect of mutation, total positive escape at site, etc.), and choose how that metric is summarized across replicates (mean or median):
[12]:
# NBVAL_IGNORE_OUTPUT
bootstrap_poly.mut_escape_lineplot(min_frac_bootstrap_replicates=0.9)
[12]:
Some tests¶
Below are just tests for approximate consistency of results with what is expected:
[13]:
sites = [417, 446, 484, 501] # just test these sites
for attr, atol in [
    ("activity_wt_df", 0.5),
    ("mut_escape_site_summary_df", 1.0),
    ("mut_escape_df", 1.0),
]:
    print(f"Testing {attr}")
    df = getattr(bootstrap_poly, attr).round(1).drop(columns="std")
    if "site" in df.columns:
        df = df.query("site in @sites").reset_index(drop=True)
    f = f"RBD_bootstrap_expected_{attr}.csv"
    expected = pd.read_csv(f).drop(columns="std")
    pd.testing.assert_frame_equal(
        df,
        expected,
        atol=atol,
        rtol=0.2,
        obj=f"{attr} DataFrame",
    )
Testing activity_wt_df
Testing mut_escape_site_summary_df
Testing mut_escape_df